import pandas as pd
import numpy as np
import statsmodels.api as sm
import statsmodels.formula.api as smf
import math
import sys
sys.path.insert(1,'/REDACTED/fairness/code/config')
from configureColors import * 
sys.path.insert(1,'/REDACTED/fairness/code/utilities')
import weightedBinscatter as wb
import estimators as et
#sys.path.append('packages')
#import binscatter
import matplotlib.font_manager as fm
import matplotlib
import statsmodels.formula.api as smf
from statsmodels.iolib.smpickle import load_pickle
import joblib
from trajectoryPlotters import *
import statsmodels.api as sm

# set plot defaults
# pyplot is imported explicitly here: nothing above this point binds `plt`
# by name, so without this line the style call below only works if one of
# the wildcard imports happens to re-export it.
import matplotlib.pyplot as plt

plt.style.use('/REDACTED/fairness/code/config/fairness.mplstyle')

# Register the bundled font file under the family name 'latex', put it at the
# front of the font manager's list, and make it the default font family.
fe = fm.FontEntry(
    fname='/REDACTED/fairness/code/config/fonts/cmunrm.ttf',
    name='latex')
fm.fontManager.ttflist.insert(0, fe)
matplotlib.rcParams['font.family'] = fe.name

# Colour and alpha values pulled out of the two colour dictionaries.
# NOTE(review): cdictBlackNonBlack / cdictEITCNon are presumably provided by
# the configureColors wildcard import -- confirm.
colorsBNB = cdictBlackNonBlack()
colorsENE = cdictEITCNon()
cb = colorsBNB['Black']['color']
ab = colorsBNB['Black']['alpha']
cnb = colorsBNB['non-Black']['color']
anb = colorsBNB['non-Black']['alpha']
ce = colorsENE['eitc']['color']
cne = colorsENE['non']['color']
ae = colorsENE['eitc']['alpha']
ane = colorsENE['non']['alpha']


## linear estimator
def regEstimate(dataset, pbvarb, outcome, wvarb):
    """Fit a weighted linear regression of ``outcome`` on ``pbvarb``.

    The model ``outcome ~ pbvarb`` is fitted by weighted least squares using
    the ``wvarb`` column as weights, with the HC1 covariance estimator.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Data containing the three columns named below.
    pbvarb : str
        Name of the regressor column.
    outcome : str
        Name of the outcome column.
    wvarb : str
        Name of the column holding the regression weights.

    Returns
    -------
    tuple
        ``(coef, se, black_aud_rate, non_black_aud_rate)`` where ``coef`` and
        ``se`` are the fitted coefficient on ``pbvarb`` and its standard
        error, ``non_black_aud_rate`` is the first fitted parameter (the
        intercept for a plain ``y ~ x`` formula) and ``black_aud_rate`` is
        the sum of the first two fitted parameters.
    """
    formula = outcome + '~' + pbvarb
    model = smf.wls(formula, dataset, weights=dataset[wvarb]).fit(cov_type='HC1')
    params = model.params
    coef = params[pbvarb]
    se = model.bse[pbvarb]
    # params is indexed by term name, so positional access must go through
    # .iloc; plain integer keys on such a Series are deprecated in pandas.
    non_black_aud_rate = params.iloc[0]
    black_aud_rate = params.iloc[0] + params.iloc[1]
    return coef, se, black_aud_rate, non_black_aud_rate


## probabilistic estimator
def chenEstimate(dataset, pbvarb, outcome, wvarb):
    """Weighted probabilistic estimate of the gap in mean outcome.

    Each row contributes to the first group with weight
    ``dataset[pbvarb] * dataset[wvarb]`` and to the second group with weight
    ``(1 - dataset[pbvarb]) * dataset[wvarb]``; each group's rate is the
    correspondingly weighted mean of ``dataset[outcome]``.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Data containing the three columns named below.
    pbvarb : str
        Column with the group-membership probability.
    outcome : str
        Column with the outcome.
    wvarb : str
        Column with the row weights.

    Returns
    -------
    tuple
        ``(est, black_audit_rate, nonblack_audit_rate)`` with
        ``est = black_audit_rate - nonblack_audit_rate``.
    """
    prob_black = dataset[pbvarb]
    prob_nonblack = 1 - prob_black
    audited = dataset[outcome]
    weight = dataset[wvarb]

    black_numerator = (prob_black * audited * weight).sum()
    black_denominator = (prob_black * weight).sum()
    black_audit_rate = black_numerator / black_denominator

    nonblack_numerator = (prob_nonblack * audited * weight).sum()
    nonblack_denominator = (prob_nonblack * weight).sum()
    nonblack_audit_rate = nonblack_numerator / nonblack_denominator

    return black_audit_rate - nonblack_audit_rate, black_audit_rate, nonblack_audit_rate

# Load the CSVs written during the model runs: the total-underreporting
# oracle and regressor selections, and their refundable-credit
# ("alt_outcome") counterparts.
oracle = pd.read_csv(
    '/REDACTED/data/modeled_refactor_temp/unres_oracle_selected_taxpayer_ids.csv'
)
reg = pd.read_csv(
    '/REDACTED/data/modeled_refactor_temp/unres_reg_selected_taxpayer_ids.csv'
)
ref_oracle = pd.read_csv(
    '/REDACTED/data/modeled_refactor_temp/unres_oracle_selected_taxpayer_ids_alt_outcome.csv'
)
ref_reg = pd.read_csv(
    '/REDACTED/data/modeled_refactor_temp/unres_reg_selected_taxpayer_ids_alt_outcome.csv'
)

# Weighted probabilistic estimate for every modeled selection; each result is
# the (gap, Black rate, non-Black rate) tuple returned by chenEstimate.
oracle_results, reg_results, ref_oracle_results, ref_reg_results = (
    chenEstimate(frame, 'predicted_prob_black', 'aud_ind', 'base_weight')
    for frame in (oracle, reg, ref_oracle, ref_reg)
)


# 2014 individual file, restricted to the columns listed in keep_cols, then
# narrowed to the rows whose isEIC flag equals 1.
keep_cols = [
    'taxpayer_id',
    'predicted_prob_black',
    'aud_no_research_audits',
    'base_wgt',
    'isEIC',
]
individual2014 = pd.read_csv(
    '/REDACTED/data/final/individualBISG2014_full_final.csv',
    usecols=keep_cols,
)
eic = individual2014.loc[individual2014['isEIC'] == 1]
def chenEstimate(dataset, pbvarb='predicted_prob_black', outcome='audited', wvarb=None):
    """Probabilistic estimate of the gap in mean outcome between two groups.

    Each row contributes to the first group in proportion to
    ``dataset[pbvarb]`` and to the second group in proportion to
    ``1 - dataset[pbvarb]``; each group's rate is the correspondingly
    weighted mean of ``dataset[outcome]``.

    This definition replaces the earlier four-argument ``chenEstimate``.
    ``wvarb`` is accepted as an optional fourth argument so that calls written
    against the weighted signature keep working after the redefinition.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Data containing the columns named below.
    pbvarb : str
        Column with the group-membership probability.
    outcome : str
        Column with the outcome.
    wvarb : str or None
        Optional column of row weights. ``None`` (the default) gives the
        unweighted estimate.

    Returns
    -------
    tuple
        ``(est, black_aud_rate, non_black_aud_rate)`` with
        ``est = black_aud_rate - non_black_aud_rate``.
    """
    prob_black = dataset[pbvarb]
    audited = dataset[outcome]
    weights = None if wvarb is None else dataset[wvarb]

    def _group_rate(membership):
        # Membership-weighted mean of the outcome, optionally scaled by the
        # row weights in both numerator and denominator.
        numerator = membership * audited
        denominator = membership
        if weights is not None:
            numerator = numerator * weights
            denominator = denominator * weights
        return numerator.sum() / denominator.sum()

    # NOTE(review): pandas sums skip NaN, so a row with a missing outcome
    # drops out of the numerator but still counts in the denominator --
    # confirm the inputs have no missing outcomes.
    black_aud_rate = _group_rate(prob_black)
    non_black_aud_rate = _group_rate(1 - prob_black)
    est = black_aud_rate - non_black_aud_rate
    return est, black_aud_rate, non_black_aud_rate

# Status-quo probabilistic estimate on the EITC subset.
# NOTE(review): the modeled estimates are computed with 'base_weight' weights,
# while this call is unweighted even though 'base_wgt' is loaded in keep_cols
# -- confirm that the unweighted status-quo rate is intended.
sq_results = chenEstimate(eic, 'predicted_prob_black', 'aud_no_research_audits')

# chenEstimate returns (gap, Black rate, non-Black rate): index 1 and 2 are
# the two group rates. The order of each list matches `models` below.
black_results = [sq_results[1], oracle_results[1], reg_results[1], ref_oracle_results[1], ref_reg_results[1]]
nonblack_results = [sq_results[2], oracle_results[2], reg_results[2], ref_oracle_results[2], ref_reg_results[2]]

results_dict = {'black_ar':black_results, 'nonblack_ar':nonblack_results}

df = pd.DataFrame(results_dict)
# Tick labels, one per row of df. The "Underreporting" labels were previously
# garbled by a stray search-and-replace.
models = ('Status Quo', 'Total Underreporting Oracle', 'Total Underreporting Prediction', 'Refundable Credit Oracle', 'Refundable Credit Prediction')
x = np.arange(len(models))

def labelbar(ax, rects, nsig=2, fontsize=16, additiveeps=0):
    """Annotate each bar with its height.

    Parameters
    ----------
    ax : object with a ``text(x, y, s, **kwargs)`` method
        Axes the labels are drawn on.
    rects : iterable
        Bars exposing ``get_x()``, ``get_width()`` and ``get_height()``.
    nsig : int
        Number of digits printed after the decimal point.
    fontsize : int or float
        Font size passed to ``ax.text``.
    additiveeps : int or float
        Constant added to the label's y position.

    Returns
    -------
    None
    """
    # The format depends only on nsig, so build it once rather than per bar.
    label_format = '%.{}f'.format(nsig)
    for rect in rects:
        height = rect.get_height()
        ax.text(
            rect.get_x() + rect.get_width() / 2,  # horizontally centred on the bar
            1.005 * height + additiveeps,         # 0.5% of the height above the top
            label_format % height,
            ha='center',
            va='bottom',
            fontsize=fontsize,
        )


import matplotlib.pyplot as plt

# Grouped bar chart of the probabilistic audit rates (scaled by 100): one pair
# of bars per entry of `models`, non-Black on the left and Black on the right.
ax1 = plt.subplot()

bars_nb = ax1.bar(
    x,
    df['nonblack_ar'] * 100,
    width=0.25,
    label='Non-Black Audit Rate',
    color='#53284f',
    alpha=0.4,
)
bars_b = ax1.bar(
    x + 0.25,
    df['black_ar'] * 100,
    width=0.25,
    label='Black Audit Rate',
    color='#53284f',
    alpha=1,
)

# Ticks sit midway between the two bars of each pair.
ax1.set_xticks(x + 0.125)
ax1.set_xticklabels(models, rotation=30, ha='right')
ax1.set_ylabel('Audit Rate (probabilistic, percentage points)')
ax1.legend()

# Print the numeric height above every bar.
labelbar(ax1, bars_nb)
labelbar(ax1, bars_b)

plt.tight_layout()
plt.savefig('/REDACTED/bad_test.png')
plt.close()

# Gap between the two rates for every model.
df['disp'] = df['black_ar'].sub(df['nonblack_ar'])

# pickle is used below but is not imported by name anywhere earlier in this
# file; import it here so the loads do not depend on a wildcard import.
import pickle

# Directory holding the pickled model outputs.
defaultaxpayer_id= '/REDACTED/data/modeled_refactor_temp/'

# NOTE: pickle.load can execute arbitrary code embedded in the file; these
# paths must only ever point at trusted, locally produced outputs.
#
# Each load rebinds a name that previously held a CSV DataFrame (oracle,
# ref_oracle, reg, ref_reg) to the unpickled output object.
#
# point_full / point_boot are overwritten after every load, so each pair is
# also stored in chen_point_estimates to keep all four available.
chen_point_estimates = {}

with open(defaultaxpayer_id + 'eitc_unres_oracle_plus_dep_database_output_new_bifsg.pickle', 'rb') as f:
    oracle = pickle.load(f)

point_full = float(oracle['1.45p_mean_full_chen_fair_unres_oracle'])
point_boot = float(oracle['1.45p_mean_bootstrap_chen_fair_unres_oracle'][0])
chen_point_estimates['oracle'] = (point_full, point_boot)

with open(defaultaxpayer_id + 'eitc_unres_oracle_plus_dep_database_output_outcome_ref_cred_amt_dif_pv_new_bifsg.pickle', 'rb') as f:
    ref_oracle = pickle.load(f)

point_full = float(ref_oracle['1.45p_mean_full_chen_fair_unres_oracle'])
point_boot = float(ref_oracle['1.45p_mean_bootstrap_chen_fair_unres_oracle'][0])
chen_point_estimates['ref_oracle'] = (point_full, point_boot)

with open(defaultaxpayer_id + 'eitc_unres_reg_plus_dep_database_output_new_bifsg.pickle', 'rb') as f:
    reg = pickle.load(f)

point_full = float(reg['1.45p_mean_full_chen_fair_unres_reg'])
point_boot = float(reg['1.45p_mean_bootstrap_chen_fair_unres_reg'][0])
chen_point_estimates['reg'] = (point_full, point_boot)

with open(defaultaxpayer_id + 'eitc_unres_reg_plus_dep_database_output_outcome_ref_cred_amt_dif_pv_new_bifsg.pickle', 'rb') as f:
    ref_reg = pickle.load(f)

point_full = float(ref_reg['1.45p_mean_full_chen_fair_unres_reg'])
point_boot = float(ref_reg['1.45p_mean_bootstrap_chen_fair_unres_reg'][0])
chen_point_estimates['ref_reg'] = (point_full, point_boot)

## Status Quo
with open('/REDACTED/fairness/code/rf/data/status_quo.pickle','rb') as f:
    sq = pickle.load(f)

sq_point = sq['eitc_chen_fair']

#################### linear results

# Re-read the CSVs written during the model runs (total-underreporting oracle
# and regressor, plus their refundable-credit "alt_outcome" counterparts);
# the four names were rebound to unpickled outputs earlier in the script.
oracle = pd.read_csv(
    '/REDACTED/data/modeled_refactor_temp/unres_oracle_selected_taxpayer_ids.csv'
)
reg = pd.read_csv(
    '/REDACTED/data/modeled_refactor_temp/unres_reg_selected_taxpayer_ids.csv'
)
ref_oracle = pd.read_csv(
    '/REDACTED/data/modeled_refactor_temp/unres_oracle_selected_taxpayer_ids_alt_outcome.csv'
)
ref_reg = pd.read_csv(
    '/REDACTED/data/modeled_refactor_temp/unres_reg_selected_taxpayer_ids_alt_outcome.csv'
)


# Weighted linear estimate for every modeled selection; each result is the
# (coef, se, Black rate, non-Black rate) tuple returned by regEstimate.
oracle_results, reg_results, ref_oracle_results, ref_reg_results = (
    regEstimate(frame, 'predicted_prob_black', 'aud_ind', 'base_weight')
    for frame in (oracle, reg, ref_oracle, ref_reg)
)

## Linear Estimator
def regEstimate(dataset, pbvarb='predicted_prob_black', outcome='audited'):
    """Fit an unweighted linear regression of ``outcome`` on ``pbvarb``.

    The model ``outcome ~ pbvarb`` is fitted by ordinary least squares with
    the HC1 covariance estimator, and the number of observations used is
    printed.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Data containing the two columns named below.
    pbvarb : str
        Name of the regressor column.
    outcome : str
        Name of the outcome column.

    Returns
    -------
    tuple
        ``(coef, se, black_aud_rate, non_black_aud_rate)`` where ``coef`` and
        ``se`` are the fitted coefficient on ``pbvarb`` and its standard
        error, ``non_black_aud_rate`` is the first fitted parameter (the
        intercept for a plain ``y ~ x`` formula) and ``black_aud_rate`` is
        the sum of the first two fitted parameters.
    """
    formula = outcome + ' ~ ' + pbvarb
    model = sm.OLS.from_formula(formula, data=dataset).fit(cov_type='HC1')
    params = model.params
    coef = params[pbvarb]
    se = model.bse[pbvarb]
    # params is indexed by term name, so positional access must go through
    # .iloc; plain integer keys on such a Series are deprecated in pandas.
    non_black_aud_rate = params.iloc[0]
    black_aud_rate = params.iloc[0] + params.iloc[1]
    print("number of obs :" + str(model.nobs))
    return coef, se, black_aud_rate, non_black_aud_rate

# Status-quo linear estimate on the EITC subset (unweighted OLS).
sq_results = regEstimate(eic, 'predicted_prob_black', 'aud_no_research_audits')

# regEstimate returns (coef, se, Black rate, non-Black rate): index 2 and 3
# are the two group rates. The order of each list matches `models` below.
black_results = [sq_results[2], oracle_results[2], reg_results[2], ref_oracle_results[2], ref_reg_results[2]]
nonblack_results = [sq_results[3], oracle_results[3], reg_results[3], ref_oracle_results[3], ref_reg_results[3]]

results_dict = {'black_ar':black_results, 'nonblack_ar':nonblack_results}

df = pd.DataFrame(results_dict)
# Tick labels, one per row of df. The "Underreporting" labels were previously
# garbled by a stray search-and-replace.
models = ('Status Quo', 'Total Underreporting Oracle', 'Total Underreporting Prediction', 'Refundable Credit Oracle', 'Refundable Credit Prediction')
x = np.arange(len(models))

def labelbar(ax, rects, nsig=2, fontsize=16, additiveeps=0):
    """Annotate each bar with its height.

    Parameters
    ----------
    ax : object with a ``text(x, y, s, **kwargs)`` method
        Axes the labels are drawn on.
    rects : iterable
        Bars exposing ``get_x()``, ``get_width()`` and ``get_height()``.
    nsig : int
        Number of digits printed after the decimal point.
    fontsize : int or float
        Font size passed to ``ax.text``.
    additiveeps : int or float
        Constant added to the label's y position.

    Returns
    -------
    None
    """
    # The format depends only on nsig, so build it once rather than per bar.
    label_format = '%.{}f'.format(nsig)
    for rect in rects:
        height = rect.get_height()
        ax.text(
            rect.get_x() + rect.get_width() / 2,  # horizontally centred on the bar
            1.005 * height + additiveeps,         # 0.5% of the height above the top
            label_format % height,
            ha='center',
            va='bottom',
            fontsize=fontsize,
        )


import matplotlib.pyplot as plt

# Grouped bar chart of the linear-estimator audit rates (scaled by 100): one
# pair of bars per entry of `models`, non-Black on the left, Black on the right.
ax1 = plt.subplot()

bars_nb = ax1.bar(
    x,
    df['nonblack_ar'] * 100,
    width=0.25,
    label='Non-Black Audit Rate',
    color='#53284f',
    alpha=0.4,
)
bars_b = ax1.bar(
    x + 0.25,
    df['black_ar'] * 100,
    width=0.25,
    label='Black Audit Rate',
    color='#53284f',
    alpha=1,
)

# Ticks sit midway between the two bars of each pair.
ax1.set_xticks(x + 0.125)
ax1.set_xticklabels(models, rotation=30, ha='right')
ax1.set_ylabel('Audit Rate (linear, percentage points)')
ax1.legend()

# Print the numeric height above every bar.
labelbar(ax1, bars_nb)
labelbar(ax1, bars_b)

plt.tight_layout()
plt.savefig('/REDACTED/bad_test_lin.png')
plt.close()

# Gap between the two rates for every model.
df['disp'] = df['black_ar'].sub(df['nonblack_ar'])

# pickle is used below but is not imported by name anywhere earlier in this
# file; import it here so the loads do not depend on a wildcard import.
import pickle

# Directory holding the pickled model outputs.
defaultaxpayer_id= '/REDACTED/data/modeled_refactor_temp/'

# NOTE: pickle.load can execute arbitrary code embedded in the file; these
# paths must only ever point at trusted, locally produced outputs.
#
# Each load rebinds a name that previously held a CSV DataFrame (oracle,
# ref_oracle, reg, ref_reg) to the unpickled output object.
#
# point_full / point_boot are overwritten after every load, so each pair is
# also stored in reg_point_estimates to keep all four available.
reg_point_estimates = {}

with open(defaultaxpayer_id + 'eitc_unres_oracle_plus_dep_database_output_new_bifsg.pickle', 'rb') as f:
    oracle = pickle.load(f)

point_full = float(oracle['1.45p_mean_full_reg_fair_unres_oracle'])
point_boot = float(oracle['1.45p_mean_bootstrap_reg_fair_unres_oracle'][0])
reg_point_estimates['oracle'] = (point_full, point_boot)

with open(defaultaxpayer_id + 'eitc_unres_oracle_plus_dep_database_output_outcome_ref_cred_amt_dif_pv_new_bifsg.pickle', 'rb') as f:
    ref_oracle = pickle.load(f)

point_full = float(ref_oracle['1.45p_mean_full_reg_fair_unres_oracle'])
point_boot = float(ref_oracle['1.45p_mean_bootstrap_reg_fair_unres_oracle'][0])
reg_point_estimates['ref_oracle'] = (point_full, point_boot)

with open(defaultaxpayer_id + 'eitc_unres_reg_plus_dep_database_output_new_bifsg.pickle', 'rb') as f:
    reg = pickle.load(f)

point_full = float(reg['1.45p_mean_full_reg_fair_unres_reg'])
point_boot = float(reg['1.45p_mean_bootstrap_reg_fair_unres_reg'][0])
reg_point_estimates['reg'] = (point_full, point_boot)

with open(defaultaxpayer_id + 'eitc_unres_reg_plus_dep_database_output_outcome_ref_cred_amt_dif_pv_new_bifsg.pickle', 'rb') as f:
    ref_reg = pickle.load(f)

point_full = float(ref_reg['1.45p_mean_full_reg_fair_unres_reg'])
point_boot = float(ref_reg['1.45p_mean_bootstrap_reg_fair_unres_reg'][0])
reg_point_estimates['ref_reg'] = (point_full, point_boot)

## Status Quo
with open('/REDACTED/fairness/code/rf/data/status_quo.pickle','rb') as f:
    sq = pickle.load(f)

sq_point = sq['eitc_reg_fair']







